#!/usr/bin/env python

from __future__ import print_function, unicode_literals

import argparse
import logging
import os
import subprocess
import sys
import tempfile

# Emit bare messages (no level or logger-name prefix) at INFO and above.
# With this level the logging.debug() calls in this script are not shown.
logging.basicConfig(format='%(message)s', level=logging.INFO)


class RoundTripTask(object):
    """One round-trip check of a single file with swift-syntax-test.

    The task runs ``<swift_syntax_test> <action> -input-source-filename
    <input_filename>`` and compares what the tool wrote to stdout against the
    input file using ``/usr/bin/diff -u``.

    After :meth:`run` has been called, ``returncode`` holds the tool's exit
    status and ``stdout`` / ``stderr`` hold its raw captured output as byte
    strings (``str`` on Python 2, ``bytes`` on Python 3).
    """

    def __init__(self, input_filename, action, swift_syntax_test,
                 skip_bad_syntax):
        """Validate the arguments and store them on the task.

        Args:
            input_filename: Text (unicode) path of the file to round-trip.
            action: Either '-round-trip-parse' or '-round-trip-lex'.
            swift_syntax_test: Path of the swift-syntax-test executable, as
                a native ``str``.
            skip_bad_syntax: When true, a non-zero exit status of the tool
                makes :meth:`run` skip the file instead of raising.

        Raises:
            AssertionError: If an argument has the wrong type or value, or
                if either path does not name an existing file.
        """
        assert action in ('-round-trip-parse', '-round-trip-lex')
        # type(u'') is `unicode` on Python 2 and `str` on Python 3, so the
        # check works on both without referencing the Python 2-only name.
        assert isinstance(input_filename, type(u''))
        assert isinstance(swift_syntax_test, str)

        assert os.path.isfile(input_filename), \
            "Input file {} is not accessible!".format(input_filename)
        assert os.path.isfile(swift_syntax_test), \
            "{} tool is not accessible!".format(swift_syntax_test)
        self.input_filename = input_filename
        self.action = action
        self.swift_syntax_test = swift_syntax_test
        self.skip_bad_syntax = skip_bad_syntax
        self.returncode = None
        self.stdout = None
        self.stderr = None

    @staticmethod
    def _as_text(data):
        """Return captured process output as the native ``str`` type.

        On Python 2 a byte string already is ``str`` and is returned
        unchanged; on Python 3 ``bytes`` are decoded as UTF-8, replacing
        undecodable bytes so that logging can never fail on odd input.
        """
        if isinstance(data, bytes) and not isinstance(data, str):
            return data.decode('utf-8', 'replace')
        return data

    @property
    def test_command(self):
        """Argument list that invokes swift-syntax-test for this task."""
        return [self.swift_syntax_test, self.action,
                '-input-source-filename', self.input_filename]

    @property
    def diff_command(self):
        """Argument list that diffs the input file against stdin ('-')."""
        return ['/usr/bin/diff', '-u', self.input_filename, '-']

    def diff(self):
        """Diff the input file against the tool output in ``self.stdout``.

        Returns:
            None when diff exits with status 0 (no differences); otherwise
            whatever diff printed on stdout, as native text.
        """
        logging.debug(' '.join(self.diff_command))
        diff = subprocess.Popen(self.diff_command, stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
        # self.stdout is kept as raw bytes so the comparison is byte-exact
        # and can be written to the binary pipe on Python 3 as well.
        stdout, stderr = diff.communicate(self.stdout)
        if diff.returncode != 0:
            return self._as_text(stdout)
        assert stdout == b''
        assert stderr == b''
        return None

    def run(self):
        """Run swift-syntax-test and diff its output against the input.

        Returns:
            None when the round trip reproduced the file exactly, or when
            the tool exited non-zero and ``skip_bad_syntax`` is set;
            otherwise the diff text returned by :meth:`diff`.

        Raises:
            RuntimeError: If the tool exited non-zero and
                ``skip_bad_syntax`` is not set.
        """
        command = self.test_command
        logging.debug(' '.join(command))

        # Capture through pipes rather than temporary files: nothing is left
        # behind if starting the process fails, and the output stays bytes.
        process = subprocess.Popen(command, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)
        self.stdout, self.stderr = process.communicate()
        self.returncode = process.returncode

        if self.returncode != 0:
            if self.skip_bad_syntax:
                logging.warning('---===WARNING===--- Lex/parse had error'
                                ' diagnostics, so not diffing. Skipping'
                                ' this file due to -skip-bad-syntax.')
                logging.error(' '.join(command))
                return None
            logging.error('---===ERROR===--- Lex/parse had error'
                          ' diagnostics, so not diffing.')
            logging.error(' '.join(command))
            logging.error(self._as_text(self.stdout))
            logging.error(self._as_text(self.stderr))
            raise RuntimeError()

        return self.diff()


def swift_files_in_dir(d):
    """Return absolute paths of all '.swift' files found beneath *d*.

    The directory tree is walked recursively with ``os.walk``. The returned
    paths have the same string type (bytes or text) as *d*, because that is
    what ``os.walk`` yields.

    Args:
        d: Directory to search, as a byte string or text string.

    Returns:
        A list of absolute paths, in ``os.walk`` order. The list is empty
        when *d* contains no matching files or does not exist.
    """
    swift_files = []
    for root, _dirs, files in os.walk(d):
        for basename in files:
            # Compare against a suffix of the same type instead of decoding:
            # text names have no usable .decode(), and byte names need not
            # be valid UTF-8.
            suffix = b'.swift' if isinstance(basename, bytes) else u'.swift'
            if not basename.endswith(suffix):
                continue
            swift_files.append(os.path.abspath(os.path.join(root, basename)))
    return swift_files


def run_task(task):
    """Run one task and report whether it failed.

    Args:
        task: An object with a ``run()`` method returning a diff or None,
            and a ``test_command`` list used for the error report.

    Returns:
        True if ``task.run()`` returned a diff or raised RuntimeError,
        False otherwise.
    """
    try:
        diff = task.run()
    except RuntimeError as e:
        # BaseException.message no longer exists on Python 3; str(e) yields
        # the same text on both Python 2 and 3.
        logging.error(str(e))
        return True

    if diff is None:
        return False

    logging.error('---===ERROR===--- Diff failed!')
    logging.error(' '.join(task.test_command))
    logging.error(diff)
    logging.error('')
    return True


def main():
    """Parse the command line, run every round-trip task and exit.

    Each input file (given with --file or found beneath a --directory) is
    checked twice: once with -round-trip-lex and once with
    -round-trip-parse.

    Exits with status 1 when there are no input files or when at least one
    task failed, and with status 0 otherwise.

    Raises:
        RuntimeError: If the path given with --swift-syntax-test is not an
            existing file.
    """
    tool_description = '''
Checks for round-trip lex/parse/print compatibility.

Swift's syntax representation should be "full-fidelity", meaning that there is
a perfect representation of what is in the source. When printing a syntax tree
to a file, that file should be identical to the file that was
originally parsed.

This driver invokes swift-syntax-test using -round-trip-lex and
-round-trip-parse on .swift files and .swift files in directories.
'''
    parser = argparse.ArgumentParser(description=tool_description)
    parser.add_argument('--directory', '-d', action='append',
                        dest='input_directories', default=[],
                        help='Add a directory, searching for .swift files'
                             ' within')
    parser.add_argument('--file', '-f', action='append',
                        dest='individual_input_files', default=[],
                        help='Add an individual file to test')
    parser.add_argument('--swift-syntax-test', '-t', required=True,
                        dest='tool_path',
                        help='Absolute path to the swift-syntax-test tool')
    parser.add_argument('--skip-bad-syntax',
                        action='store_true',
                        default=False,
                        help="Skip files that caused lex or parse diagnostics"
                             " to be emitted")

    args = parser.parse_args()

    all_input_files = [filename
                       for directory in args.input_directories
                       for filename in swift_files_in_dir(directory)]
    all_input_files += args.individual_input_files
    # Only byte strings need decoding; paths that are already text (always
    # the case on Python 3) are passed through untouched.
    all_input_files = [f.decode('utf-8') if isinstance(f, bytes) else f
                       for f in all_input_files]

    if not all_input_files:
        logging.error('No input files!')
        sys.exit(1)

    if not os.path.isfile(args.tool_path):
        raise RuntimeError("Couldn't find swift-syntax-test at {}"
                           .format(args.tool_path))

    lex_tasks = [RoundTripTask(filename, '-round-trip-lex', args.tool_path,
                               args.skip_bad_syntax)
                 for filename in all_input_files]
    parse_tasks = [RoundTripTask(filename, '-round-trip-parse', args.tool_path,
                                 args.skip_bad_syntax)
                   for filename in all_input_files]

    # Build the full list first so every task runs and reports its own
    # errors; the driver has failed if ANY single task failed.
    results = [run_task(task) for task in lex_tasks + parse_tasks]
    sys.exit(1 if any(results) else 0)


# Run the driver only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()
